import os
import unittest

import torch
from torch.backends import cudnn
from torchsummary import summary

from models import build_aspp

# Expose only the first GPU to this process. CUDA reads this variable when it
# is first initialized, so it has to be set before any CUDA call is made.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Let cuDNN benchmark convolution algorithms and reuse the fastest one.
cudnn.benchmark = True


class ASPPTestCase(unittest.TestCase):
    """Smoke test: build the ASPP module on the GPU and print its summary."""

    # Channel count of the input feature map. It is shared by the model
    # constructor and the summary input shape so the two stay in sync.
    IN_CHANNELS = 2048

    @unittest.skipUnless(torch.cuda.is_available(), "CUDA device required")
    def test_aspp(self):
        """Build ASPP, move it to CUDA and run ``summary`` on it.

        A missing GPU is an environment limitation, not a model defect, so
        the test is skipped (not failed) when CUDA is unavailable. The test
        passes as long as model construction and the summary call complete
        without raising; no output values are checked.
        """
        model = build_aspp(in_channels=self.IN_CHANNELS)
        model.cuda()
        # (channels, height, width) without the batch dimension, following
        # torchsummary's ``input_size`` convention.
        input_size = (self.IN_CHANNELS, 19, 19)
        summary(model, input_size=input_size)


if __name__ == "__main__":
    # Discover and run every TestCase in this module when executed directly.
    unittest.main(module="__main__")
